import pandas as pd
import talib
from backtesting import Backtest, Strategy

# Clean and preprocess data
def preprocess_data(data_path):
    data = pd.read_csv(data_path)
    # Clean column names
    data.columns = data.columns.str.strip().str.lower()
    # Drop unnamed columns
    data = data.drop(columns=[col for col in data.columns if 'unnamed' in col.lower()])
    # Ensure proper column mapping
    data = data.rename(columns={
        'open': 'Open',
        'high': 'High',
        'low': 'Low',
        'close': 'Close',
        'volume': 'Volume'
    })
    # Convert datetime column to proper format
    data['datetime'] = pd.to_datetime(data['datetime'])
    data = data.set_index('datetime')
    return data

# MomentumRejection Strategy
class MomentumRejection(Strategy):
    # Strategy parameters
    stoch_k_period = 14
    stoch_d_period = 3
    stoch_overbought = 80
    stoch_oversold = 20
    risk_per_trade = 0.01  # 1% risk per trade
    risk_reward_ratio = 2  # 1:2 risk-reward ratio

    def init(self):
        # Calculate Stochastic Oscillator
        self.stoch_k, self.stoch_d = self.I(talib.STOCH, self.data.High, self.data.Low, self.data.Close,
                                            fastk_period=self.stoch_k_period, slowk_period=self.stoch_d_period,
                                            slowd_period=self.stoch_d_period)
        # Calculate trendlines (using rolling highs/lows as proxies)
        self.uptrend_line = self.I(talib.MIN, self.data.Low, timeperiod=20)  # Higher lows for uptrend
        self.downtrend_line = self.I(talib.MAX, self.data.High, timeperiod=20)  # Lower highs for downtrend

    def next(self):
        # Current price and indicators
        price = self.data.Close[-1]
        stoch_k = self.stoch_k[-1]
        stoch_d = self.stoch_d[-1]
        uptrend_line = self.uptrend_line[-1]
        downtrend_line = self.downtrend_line[-1]

        # Risk management: Calculate position size
        position_size = self.risk_per_trade * self.equity / price

        # Entry logic for continuation pattern (uptrend)
        if price > uptrend_line and stoch_k < self.stoch_oversold and stoch_d < self.stoch_oversold:
            if stoch_k[-2] < stoch_d[-2] and stoch_k[-1] > stoch_d[-1]:  # Stochastic crossover confirmation
                self.buy(size=position_size)
                print(f"🌙 Moon Dev Buy Signal: Continuation Uptrend | Price: {price} | Stochastic: {stoch_k}, {stoch_d}")

        # Entry logic for continuation pattern (downtrend)
        elif price < downtrend_line and stoch_k > self.stoch_overbought and stoch_d > self.stoch_overbought:
            if stoch_d[-2] < stoch_k[-2] and stoch_d[-1] > stoch_k[-1]:  # Stochastic crossover confirmation
                self.sell(size=position_size)
                print(f"🌙 Moon Dev Sell Signal: Continuation Downtrend | Price: {price} | Stochastic: {stoch_k}, {stoch_d}")

        # Entry logic for breakout pattern (uptrend reversal)
        if price < uptrend_line and stoch_k > self.stoch_overbought and stoch_d > self.stoch_overbought:
            if stoch_d[-2] < stoch_k[-2] and stoch_d[-1] > stoch_k[-1]:  # Stochastic crossover confirmation
                self.sell(size=position_size)
                print(f"🌙 Moon Dev Sell Signal: Breakout Uptrend Reversal | Price: {price} | Stochastic: {stoch_k}, {stoch_d}")

        # Entry logic for breakout pattern (downtrend reversal)
        elif price > downtrend_line and stoch_k < self.stoch_oversold and stoch_d < self.stoch_oversold:
            if stoch_k[-2] < stoch_d[-2] and stoch_k[-1] > stoch_d[-1]:  # Stochastic crossover confirmation
                self.buy(size=position_size)
                print(f"🌙 Moon Dev Buy Signal: Breakout Downtrend Reversal | Price: {price} | Stochastic: {stoch_k}, {stoch_d}")

        # Exit logic
        for trade in self.trades:
            if trade.is_long:
                # Take profit: previous swing high
                take_profit = self.data.High[-20:].max()
                # Stop loss: recent swing low
                stop_loss = self.data.Low[-20:].min()
                trade.sl = stop_loss
                trade.tp = take_profit
            elif trade.is_short:
                # Take profit: previous swing low
                take_profit = self.data.Low[-20:].min()
                # Stop loss: recent swing high
                stop_loss = self.data.High[-20:].max()
                trade.sl = stop_loss
                trade.tp = take_profit

# Load and preprocess data
data_path = "/Users/md/Dropbox/dev/github/moon-dev-ai-agents-for-trading/src/data/rbi/BTC-USD-15m.csv"
data = preprocess_data(data_path)

# Initialize and run backtest
bt = Backtest(data, MomentumRejection, cash=1_000_000, commission=0.002)
stats = bt.run()
print(stats)
print(stats._strategy)

# Plot initial performance
bt.plot()

# Optimize parameters
optimization_results = bt.optimize(
    stoch_k_period=range(10, 20, 2),
    stoch_d_period=range(2, 5, 1),
    stoch_overbought=range(70, 90, 5),
    stoch_oversold=range(10, 30, 5),
    risk_per_trade=[0.01, 0.02],
    risk_reward_ratio=[2, 3],
    maximize='Return [%]'
)
print(optimization_results)

# Plot optimized performance
bt.plot()